{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.5"
    },
    "colab": {
      "name": "Keypoint_model_training.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "igMyGnjE9hEp"
      },
      "source": [
        "import csv\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "RANDOM_SEED = 42"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t2HDvhIu9hEr"
      },
      "source": [
        "# Specify each path"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9NvZP2Zn9hEy"
      },
      "source": [
        "# Specify data paths\n",
        "dataset = 'keypoint.csv'\n",
        "model_save_path = 'keypoint_classifier/keypoint_classifier.hdf5'\n",
        "tflite_save_path = 'keypoint_classifier/keypoint_classifier.tflite'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s5oMH7x19hEz"
      },
      "source": [
        "# Set number of classes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "du4kodXL9hEz"
      },
      "source": [
        "# Change training classes if necessary\n",
        "NUM_CLASSES = 8"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XjnL0uso9hEz"
      },
      "source": [
        "# Dataset reading"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QT5ZqtEz9hE0"
      },
      "source": [
        "X_dataset = np.loadtxt(dataset, delimiter=',', dtype='float32', usecols=list(range(1, (21 * 2) + 1)))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QmoKFsp49hE0"
      },
      "source": [
        "y_dataset = np.loadtxt(dataset, delimiter=',', dtype='int32', usecols=(0))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xQU7JTZ_9hE0"
      },
      "source": [
        "X_train, X_test, y_train, y_test = train_test_split(X_dataset, y_dataset, train_size=0.75, random_state=RANDOM_SEED)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 279
        },
        "id": "xElG5FoPDQO9",
        "outputId": "2ef372ed-62e3-49c1-ad36-a5b5dc76701a"
      },
      "source": [
        "# Classes count\n",
        "counts = np.unique(y_dataset, return_counts=True)\n",
        "df = pd.DataFrame(counts)\n",
        "df.T.plot(kind=\"bar\", stacked=True)\n",
        "print(counts)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32), array([1595, 1663, 1510,  672,  164,  257,  139,  190]))\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD1CAYAAAC87SVQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAU5klEQVR4nO3dfZBd9X3f8ffHCJDtuDwuFGtFJAeZRNDYpgvGQ+uxQ8KD7EFMx3FR0yAMrqatSJw6Uxs7mcFNygxuPaV47HpGNbIh40IIdYrSUmMVO/GkDQ8C22DAmA3Y1mp4WAuMExPMg7/94/4wy7LLavde3V1x3q+ZnT3ne37nnO9K2s89+t1z701VIUnqhlctdgOSpOEx9CWpQwx9SeoQQ1+SOsTQl6QOMfQlqUOWLXYDL+fwww+vVatWLXYbkrRPuf32239QVSMzbVvSob9q1Sp27Nix2G1I0j4lyfdm2+b0jiR1iKEvSR1i6EtShyzpOX1JWizPPPMMExMTPPXUU4vdyqyWL1/O6Ogo+++//x7vY+hL0gwmJiZ43etex6pVq0iy2O28RFWxe/duJiYmWL169R7v5/SOJM3gqaee4rDDDluSgQ+QhMMOO2ze/xMx9CVpFks18J+3kP4MfUlawr70pS9x7LHHcswxx3DppZf2fTzn9IftYwfthWM+MfhjSnqRVRf9r4Ee77uXvmvOMc899xybN29m+/btjI6OcuKJJ3LWWWexdu3aBZ/XK31JWqJuvfVWjjnmGN7whjdwwAEHcM4553D99df3dUxDX5KWqF27drFy5cqfrY+OjrJr166+jmnoS1KHGPqStEStWLGCnTt3/mx9YmKCFStW9HXMOUM/ydYkjyb51rT6byX5dpK7k/yHKfWPJBlPcl+S06fUz2i18SQX9dW1JHXAiSeeyP3338+DDz7I008/zTXXXMNZZ53V1zH35O6dzwOfAq56vpDkncB64E1V9ZMkR7T6WuAc4Djg9cD/SfLGttungV8DJoDbkmyrqnv66l6SXsGWLVvGpz71KU4//XSee+45zj//fI477rj+jjnXgKr6WpJV08r/Cri0qn7Sxjza6uuBa1r9wSTjwElt23hVPQCQ5Jo21tCXtE/Yk1ss94Z169axbt26gR1voXP6bwT+cZJbkvxFkhNbfQWwc8q4iVabrf4SSTYl2ZFkx+Tk5ALbkyTNZKEvzloGHAqcDJwIXJvkDYNoqKq2AFsAxsbGao939EVPkjSnhYb+BPDFqirg1iQ/BQ4HdgErp4wbbTVepi5JGpKFTu/8D+CdAO2J2gOAHwDbgHOSHJhkNbAGuBW4DViTZHWSA+g92but3+YlSfMz55V+kquBdwCHJ5kALga2AlvbbZxPAxvbVf/dSa6l9wTts8DmqnquHedC4EZgP2BrVd29F34eSdLL2JO7dzbMsumfzzL+EuCSGeo3ADfMqztJ0kD5ilxJWqLOP/98jjjiCI4//viBHdO3VpakPTHoOwT34O7A8847jwsvvJBzzz13YKf1Sl+Slqi3v/3tHHrooQM9pqEvSR1i6EtShzinr5n5CmfpFckrfUnqEENfkpaoDRs28La3vY377ruP0dFRrrjiir6P6fSOJO2JRZievPrqqwd+TK/0JalDDH1J6hBDX5I6xNCXpFn03jx46VpIf4a+JM1g+fLl7N69e8kGf1Wxe/duli9fPq/9vHtHkmYwOjrKxMQES/mzupcvX87o6Oi89tmTD1HZCrwbeLSqjp+27XeBTwAjVfWDJAEuB9YBTwLnVdUdbexG4Pfbrv++qq6cV6eSNET7778/q1evXuw2Bm5Ppnc+D5wxvZhkJXAa8P0p5TPpfUTiGmAT8Jk29lB6n7j1VuAk4OIkh/TTuCRp/uYM/ar6GvDYDJsuAz4ETJ3wWg9cVT03AwcnOQo4HdheVY9V1ePAdmZ4IJEk7V0LeiI3yX
pgV1V9c9qmFcDOKesTrTZbXZI0RPN+IjfJa4CP0pvaGbgkm+hNDXH00UfvjVNIUmct5Er/F4DVwDeTfBcYBe5I8veBXcDKKWNHW222+ktU1ZaqGquqsZGRkQW0J0mazbxDv6ruqqojqmpVVa2iN1VzQlU9DGwDzk3PycATVfUQcCNwWpJD2hO4p7WaJGmI5gz9JFcDfwUcm2QiyQUvM/wG4AFgHPivwL8GqKrHgD8Ebmtff9BqkqQhmnNOv6o2zLF91ZTlAjbPMm4rsHWe/UmSBsi3YZCkDjH0JalDDH1J6hBDX5I6xNCXpA4x9CWpQwx9SeoQQ1+SOsTQl6QOMfQlqUMMfUnqEENfkjrE0JekDjH0JalDDH1J6hBDX5I6ZE8+OWtrkkeTfGtK7T8m+XaSO5P8aZKDp2z7SJLxJPclOX1K/YxWG09y0eB/FEnSXPbkSv/zwBnTatuB46vql4HvAB8BSLIWOAc4ru3zX5Lsl2Q/4NPAmcBaYEMbK0kaojlDv6q+Bjw2rfblqnq2rd4MjLbl9cA1VfWTqnqQ3mflntS+xqvqgap6GrimjZUkDdEg5vTPB/53W14B7JyybaLVZqtLkoaor9BP8nvAs8AXBtMOJNmUZEeSHZOTk4M6rCSJPkI/yXnAu4HfqKpq5V3AyinDRltttvpLVNWWqhqrqrGRkZGFtidJmsGCQj/JGcCHgLOq6skpm7YB5yQ5MMlqYA1wK3AbsCbJ6iQH0Huyd1t/rUuS5mvZXAOSXA28Azg8yQRwMb27dQ4EticBuLmq/mVV3Z3kWuAeetM+m6vquXacC4Ebgf2ArVV19174eSRJL2PO0K+qDTOUr3iZ8ZcAl8xQvwG4YV7dSZIGylfkSlKHGPqS1CGGviR1iKEvSR1i6EtShxj6ktQhhr4kdYihL0kdYuhLUocY+pLUIYa+JHWIoS9JHWLoS1KHGPqS1CGGviR1iKEvSR0yZ+gn2Zrk0STfmlI7NMn2JPe374e0epJ8Msl4kjuTnDBln41t/P1JNu6dH0eS9HL25Er/88AZ02oXATdV1RrgprYOcCa9z8VdA2wCPgO9Bwl6H7P4VuAk4OLnHygkScMzZ+hX1deAx6aV1wNXtuUrgbOn1K+qnpuBg5McBZwObK+qx6rqcWA7L30gkSTtZQud0z+yqh5qyw8DR7blFcDOKeMmWm22uiRpiPp+IreqCqgB9AJAkk1JdiTZMTk5OajDSpJYeOg/0qZtaN8fbfVdwMop40Zbbbb6S1TVlqoaq6qxkZGRBbYnSZrJQkN/G/D8HTgbgeun1M9td/GcDDzRpoFuBE5Lckh7Ave0VpMkDdGyuQYkuRp4B3B4kgl6d+FcClyb5ALge8B72/AbgHXAOPAk8D6AqnosyR8Ct7Vxf1BV058cliTtZXOGflVtmGXTqTOMLWDzLMfZCmydV3eSpIHyFbmS1CGGviR1iKEvSR1i6EtShxj6ktQhhr4kdYihL0kdYuhLUocY+pLUIYa+JHWIoS9JHWLoS1KHGPqS1CGGviR1iKEvSR1i6EtSh/QV+kn+TZK7k3wrydVJlidZneSWJONJ/jjJAW3sgW19vG1fNYgfQJK05xYc+klWAL8NjFXV8cB+wDnAx4HLquoY4HHggrbLBcDjrX5ZGydJGqJ+p3eWAa9Osgx4DfAQ8CvAdW37lcDZbXl9W6dtPzVJ+jy/JGkeFhz6VbUL+ATwfXph/wRwO/DDqnq2DZsAVrTlFcDOtu+zbfxhCz2/JGn++pneOYTe1ftq4PXAa4Ez+m0oyaYkO5LsmJyc7PdwkqQp+pne+VXgwaqarKpngC8CpwAHt+kegFFgV1veBawEaNsPAnZPP2hVbamqsaoaGxkZ6aM9SdJ0/YT+94GTk7ymzc2fCtwDfBV4TxuzEbi+LW9r67TtX6mq6uP8kqR56mdO/xZ6T8jeAdzVjrUF+DDwwSTj9Obsr2i7XAEc1uofBC7qo29J0gIsm3vI7KrqYuDiae
UHgJNmGPsU8Ov9nE+S1B9fkStJHWLoS1KHGPqS1CGGviR1iKEvSR1i6EtShxj6ktQhhr4kdYihL0kdYuhLUof09TYM0qL62EF74ZhPDP6Y0hLilb4kdYihL0kdYuhLUocY+pLUIYa+JHVIX6Gf5OAk1yX5dpJ7k7wtyaFJtie5v30/pI1Nkk8mGU9yZ5ITBvMjSJL2VL9X+pcDX6qqXwTeBNxL72MQb6qqNcBNvPCxiGcCa9rXJuAzfZ5bkjRPCw79JAcBb6d9Bm5VPV1VPwTWA1e2YVcCZ7fl9cBV1XMzcHCSoxbcuSRp3vq50l8NTAKfS/L1JJ9N8lrgyKp6qI15GDiyLa8Adk7Zf6LVJElD0k/oLwNOAD5TVW8BfswLUzkAVFUBNZ+DJtmUZEeSHZOTk320J0marp/QnwAmquqWtn4dvQeBR56ftmnfH23bdwErp+w/2movUlVbqmqsqsZGRkb6aE+SNN2CQ7+qHgZ2Jjm2lU4F7gG2ARtbbSNwfVveBpzb7uI5GXhiyjSQJGkI+n3Dtd8CvpDkAOAB4H30HkiuTXIB8D3gvW3sDcA6YBx4so2VJA1RX6FfVd8AxmbYdOoMYwvY3M/5JEn98RW5ktQhhr4kdYihL0kdYuhLUocY+pLUIYa+JHWIoS9JHWLoS1KHGPqS1CGGviR1iKEvSR1i6EtShxj6ktQhhr4kdYihL0kdYuhLUof0HfpJ9kvy9ST/s62vTnJLkvEkf9w+VYskB7b18bZ9Vb/nliTNzyCu9D8A3Dtl/ePAZVV1DPA4cEGrXwA83uqXtXGSpCHqK/STjALvAj7b1gP8CnBdG3IlcHZbXt/WadtPbeMlSUPS75X+fwY+BPy0rR8G/LCqnm3rE8CKtrwC2AnQtj/RxkuShmTBoZ/k3cCjVXX7APshyaYkO5LsmJycHOShJanz+rnSPwU4K8l3gWvoTetcDhycZFkbMwrsasu7gJUAbftBwO7pB62qLVU1VlVjIyMjfbQnSZpuwaFfVR+pqtGqWgWcA3ylqn4D+CrwnjZsI3B9W97W1mnbv1JVtdDzS5Lmb2/cp/9h4INJxunN2V/R6lcAh7X6B4GL9sK5JUkvY9ncQ+ZWVX8O/HlbfgA4aYYxTwG/PojzSZIWxlfkSlKHGPqS1CGGviR1iKEvSR1i6EtShxj6ktQhhr4kdYihL0kdYuhLUocY+pLUIYa+JHWIoS9JHTKQN1yT9ArwsYP2wjGfGPwx1Rev9CWpQwx9SeoQQ1+SOqSfD0ZfmeSrSe5JcneSD7T6oUm2J7m/fT+k1ZPkk0nGk9yZ5IRB/RCSpD3Tz5X+s8DvVtVa4GRgc5K19D4G8aaqWgPcxAsfi3gmsKZ9bQI+08e5JUkL0M8Hoz9UVXe05b8B7gVWAOuBK9uwK4Gz2/J64KrquRk4OMlRC+5ckjRvA5nTT7IKeAtwC3BkVT3UNj0MHNmWVwA7p+w20WqSpCHp+z79JD8H/Hfgd6rqR0l+tq2qKknN83ib6E3/cPTRR/fbniQtjiX6uoe+rvST7E8v8L9QVV9s5Ueen7Zp3x9t9V3Ayim7j7bai1TVlqoaq6qxkZGRftqTJE3Tz907Aa4A7q2q/zRl0zZgY1veCFw/pX5uu4vnZOCJKdNAkqQh6Gd65xTgN4G7knyj1T4KXApcm+QC4HvAe9u2G4B1wDjwJPC+Ps4tSVqABYd+Vf0lkFk2nzrD+AI2L/R80j5ric7tqpt8Ra4kdYjvsilp3+L/nPrilb4kdYihL0kdYuhLUocY+pLUIYa+JHWIoS9JHWLoS1KHGPqS1CGGviR1iKEvSR1i6EtShxj6ktQhhr4kdYihL0kdMvTQT3JGkvuSjCe5aNjnl6QuG2roJ9kP+DRwJrAW2JBk7TB7kKQuG/aV/knAeFU9UFVPA9cA64fcgyR1VnofXTukkyXvAc6oqve39d8E3lpVF04ZswnY1FaPBe4bcBuHAz8Y8DH3BvscLPscrH2hz32hR9
g7ff58VY3MtGHJfVxiVW0Btuyt4yfZUVVje+v4g2Kfg2Wfg7Uv9Lkv9AjD73PY0zu7gJVT1kdbTZI0BMMO/duANUlWJzkAOAfYNuQeJKmzhjq9U1XPJrkQuBHYD9haVXcPswf24tTRgNnnYNnnYO0Lfe4LPcKQ+xzqE7mSpMXlK3IlqUMMfUnqEENfkjrkFR/6SX4xyYeTfLJ9fTjJLy12X/uq9ud5apKfm1Y/Y7F6mkmSk5Kc2JbXJvlgknWL3dfLSXLVYvcwlyT/qP1ZnrbYvUyV5K1J/l5bfnWSf5fkz5J8PMlBi93f85L8dpKVc4/ciz28kp/ITfJhYAO9t3uYaOVRereKXlNVly5Wb3sqyfuq6nOL3Qf0/sECm4F7gTcDH6iq69u2O6rqhMXs73lJLqb3/k7LgO3AW4GvAr8G3FhVlyxiewAkmX6rcoB3Al8BqKqzht7UDJLcWlUnteV/Qe/v/0+B04A/Wyq/Q0nuBt7U7hDcAjwJXAec2ur/ZFEbbJI8AfwY+GvgauBPqmpyqD28wkP/O8BxVfXMtPoBwN1VtWZxOttzSb5fVUcvdh8ASe4C3lZVf5tkFb1fqj+qqsuTfL2q3rKoDTatzzcDBwIPA6NV9aMkrwZuqapfXtQG6T1IAvcAnwWKXuhfTe+ChKr6i8Xr7gVT/16T3Aasq6rJJK8Fbq6qf7C4HfYkubeqfqktv+gCJMk3qurNi9fdC5J8HfiHwK8C/xQ4C7id3t/9F6vqb/Z2D0vubRgG7KfA64HvTasf1bYtCUnunG0TcOQwe5nDq6rqbwGq6rtJ3gFcl+Tn6fW6VDxbVc8BTyb566r6EUBV/V2SpfL3PgZ8APg94N9W1TeS/N1SCfspXpXkEHpTwXn+qrSqfpzk2cVt7UW+NeV/xd9MMlZVO5K8EXhmrp2HqKrqp8CXgS8n2Z/e/0o3AJ8AZny/nEF6pYf+7wA3Jbkf2NlqRwPHABfOutfwHQmcDjw+rR7g/w2/nVk9kuTNVfUNgHbF/25gK7Akrviap5O8pqqepHdVBUCb210Sod9+8S9L8ift+yMszd/Hg+hdiQaoJEdV1UPtOZ2l9ED/fuDyJL9P783L/irJTnq/9+9f1M5e7EV/Zm0WYhuwLclrhtLAK3l6ByDJq+i9pfOKVtoF3NauBJeEJFcAn6uqv5xh23+rqn+2CG29RJJRelfRD8+w7ZSq+r+L0NZLJDmwqn4yQ/1w4KiqumsR2npZSd4FnFJVH13sXvZEC6gjq+rBxe5lqvZk7mp6D6ATVfXIIrf0IkneWFXfWdQeXumhL0l6wSv+lk1J0gsMfUnqEENfkjrE0JekDjH0JalD/j9bz/kf8CDV3AAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mxK_lETT9hE0"
      },
      "source": [
        "# Model building"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vHBmUf1t9hE1"
      },
      "source": [
        "model = tf.keras.models.Sequential([\n",
        "    tf.keras.layers.Input((21 * 2, )),\n",
        "    tf.keras.layers.Dropout(0.0),\n",
        "    tf.keras.layers.Dense(32, activation='relu'),\n",
        "    tf.keras.layers.Dropout(0.0),\n",
        "    tf.keras.layers.Dense(32, activation='relu'),\n",
        "    tf.keras.layers.Dropout(0.0),\n",
        "    tf.keras.layers.Dense(16, activation='relu'),\n",
        "    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')\n",
        "])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ypqky9tc9hE1",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c42f3550-ceee-45b8-d40d-d99fd84a2616"
      },
      "source": [
        "model.summary()  # tf.keras.utils.plot_model(model, show_shapes=True)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model: \"sequential_194\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "dropout_582 (Dropout)        (None, 42)                0         \n",
            "_________________________________________________________________\n",
            "dense_776 (Dense)            (None, 32)                1376      \n",
            "_________________________________________________________________\n",
            "dropout_583 (Dropout)        (None, 32)                0         \n",
            "_________________________________________________________________\n",
            "dense_777 (Dense)            (None, 32)                1056      \n",
            "_________________________________________________________________\n",
            "dropout_584 (Dropout)        (None, 32)                0         \n",
            "_________________________________________________________________\n",
            "dense_778 (Dense)            (None, 16)                528       \n",
            "_________________________________________________________________\n",
            "dense_779 (Dense)            (None, 8)                 136       \n",
            "=================================================================\n",
            "Total params: 3,096\n",
            "Trainable params: 3,096\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MbMjOflQ9hE1"
      },
      "source": [
        "# Model checkpoint callback\n",
        "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n",
        "    model_save_path, verbose=1, save_weights_only=False, save_best_only=True)\n",
        "# Callback for early stopping\n",
        "es_callback = tf.keras.callbacks.EarlyStopping(patience=50, verbose=1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c3Dac0M_9hE2"
      },
      "source": [
        "# Model compilation\n",
        "model.compile(\n",
        "    optimizer='adam',\n",
        "    loss='sparse_categorical_crossentropy',\n",
        "    metrics=['accuracy']\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7XI0j1Iu9hE2"
      },
      "source": [
        "# Model training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "WirBl-JE9hE3"
      },
      "source": [
        "model.fit(\n",
        "    X_train,\n",
        "    y_train,\n",
        "    epochs=1000,\n",
        "    batch_size=64,\n",
        "    validation_data=(X_test, y_test),\n",
        "    callbacks=[cp_callback, es_callback]\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RBkmDeUW9hE4"
      },
      "source": [
        "# Loading the saved model\n",
        "model = tf.keras.models.load_model(model_save_path)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pxvb2Y299hE3",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "7015e279-0501-4f24-d1b5-90652d2de17d"
      },
      "source": [
        "# Model evaluation\n",
        "# TODO Test on loaded model\n",
        "val_loss, val_acc = model.evaluate(X_test, y_test, batch_size=64)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "25/25 [==============================] - 0s 2ms/step - loss: 0.0061 - accuracy: 0.9974\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tFz9Tb0I9hE4",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bb8a62a9-bf7d-4d99-8099-bda4e2e1c73a"
      },
      "source": [
        "# Inference test\n",
        "predict_result = model.predict(np.array([X_test[0]]))\n",
        "print(np.squeeze(predict_result))\n",
        "print(np.argmax(np.squeeze(predict_result)))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[1.0000000e+00 8.5401677e-12 1.5222527e-14 7.2429063e-11 3.3295724e-09\n",
            " 2.5334020e-27 2.3049776e-16 3.3789385e-19]\n",
            "0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S3U4yNWx9hE4"
      },
      "source": [
        "# Confusion matrix"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AP1V6SCk9hE5",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 646
        },
        "outputId": "efce96b9-ca1c-44a5-ef77-58f1d5fc154c"
      },
      "source": [
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.metrics import confusion_matrix, classification_report\n",
        "\n",
        "def print_confusion_matrix(y_true, y_pred, report=True):\n",
        "    labels = sorted(list(set(y_true)))\n",
        "    cmx_data = confusion_matrix(y_true, y_pred, labels=labels)\n",
        "    \n",
        "    df_cmx = pd.DataFrame(cmx_data, index=labels, columns=labels)\n",
        " \n",
        "    fig, ax = plt.subplots(figsize=(7, 6))\n",
        "    sns.heatmap(df_cmx, annot=True, fmt='g' ,square=False)\n",
        "    ax.set_ylim(len(set(y_true)), 0)\n",
        "    plt.show()\n",
        "    \n",
        "    if report:\n",
        "        print('Classification Report')\n",
        "        print(classification_report(y_test, y_pred))\n",
        "\n",
        "Y_pred = model.predict(X_test)\n",
        "y_pred = np.argmax(Y_pred, axis=1)\n",
        "\n",
        "print_confusion_matrix(y_test, y_pred)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAFlCAYAAAAjyXUiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3df5xWc/7/8cfrmqZSEflR8yNbu/m5rKLCskQorYSPDbt+7cd3W4tdrV2xWGSxflt28dn8SJYoPyOtzSbSEoVQU0spmpkSQvqh5sfr+8ecZi/tzDV1XWeu6z1Xz7vbubnOuc51znPOcL3mfd7vc465OyIiInFI5DqAiIjkDxUVERGJjYqKiIjERkVFRERio6IiIiKxUVEREZHYtGruHVR9+kGwY5a3Kv5BriM0ynIdIIVgf6EiOVC9viK2/13j+L4s3OHbOf36aPaiIiIim6i2JtcJMqbTXyIiEhu1VEREQuG1uU6QMRUVEZFQ1KqoiIhITDwPWirqUxERkdiopSIiEgqd/hIRkdjkwekvFRURkVDoOhUREYmN12Y+bQIzKzCzt8xsYjTf3cxeM7MFZjbOzFpHy9tE8wui97s1tW0VFRGRLc/5wLyk+euBW929B/A5cFa0/Czg82j5rdF6KamoiIiEorY286kJZlYK/BC4J5o34HDgsWiVMcBx0esh0TzR+/2j9RulPhURkUBk6TqVPwEjgK2j+e2BL9y9OpovB0qi1yXAkrpsXm1mX0brf9rYxtVSEREJRQwtFTMbZmazkqZhGzZvZscAy939jeb6EdRSEREJRQwtFXcfBYxq5O2DgGPNbBDQFtgGuA3Y1sxaRa2VUqAiWr8C6AqUm1kroCPwWar9B9lSqamp4cQzz+WcC68AYOxjT3P00P9lr4OO5vMvvqxf76tVqzl3xBWccMY5DPnJz3ny2ck5yTvgqH7MnTON+WXTGXHhuTnJ0JDS0mKen/wob789ldmzX+CX553V9IeyKNTjBmFng7DzKVu43P137l7q7t2Ak4EX3P0nwFTgxGi1M4AJ0euno3mi919w95TPfAmyqDz46AS+3W3n+vle39uTe277I8VddvrGeg8//gzf6bYzT4y5k9F/uZ4b/3w3VVVVWc2aSCS4/bZrOGbwqey9z2GcdNJx7LHHLlnN0Jjq6mpGjBjJPvscxsEHD+bsX5wZTLaQj1vI2SDsfMqWodqazKf0XARcYGYLqOszuTdafi+wfbT8AuDipjbUZFExs93N7CIzuz2aLjKzPdJN3pRlyz9h2iuv8z+DB9Qv22PXHpQUdW4oG6vXrMXdWbP2azpuszUFBQXNFa1Bffv0YuHCxSxa9BFVVVWMHz+BY5Oy59KyZct5a/YcAFatWs38+e9TXNwlx6nqhHzcQs4GYedTtgxl6ToVAHd/0d2PiV5/4O593b2Hu//I3ddFy7+O5ntE73/Q1HZTFhUzuwh4hLqn274eTQY8bGZNVqx0XH/bX7ngnLMwa7oR9eP/GcwHi5dw2JCfcPzpv+Di4WeTSGS38VVc0oUl5ZX18+UVS4P54k72rW+V0nOfvXj99bdyHQUI+7iFnA3CzqdsGcrCkOLm1tQ38FlAH3e/zt0fjKbrgL785+KY/5I8+uCeBx7e5DAv/us1Om23Ld/dfdOapP96/Q123+XbTJ3wEI/ffwfX3nInq1av3uT9bSnat2/H+HF385vfXsFXX63KdRwRyWNNjf6qBYqBDzdaXhS916Dk0QdVn36QslMn2VvvlPHi9Bm8/OpM1q2vYvXqNVw08gauv2JEg+s/+ezz/L9Th2Jm7FxaTElRFxZ9WM7ee+62qbvMWGXFMrqWFtfPl5YUUVm5LGv7b0qrVq0YP+5uHn74SZ566u+5jlMv5OMWcjYIO5+yZSgPbijZVEtlODDFzP5uZqOi6TlgCnWX+cfq17/4KVOeepDJj4/hxpEX03
e/fRotKABFnXdkxhuzAfh0xecs/qic0iw3Z2fOmk2PHt3p1q0rhYWFDB06hGcm5mYUWkPuHnUz8+cv4E+3NTbCMDdCPm4hZ4Ow8ylbhvLg9FfKloq7P2dmu1J3umvDFZYVwEx3z9rtNB98dAKjH3qUT1d8zgmnn8MPDuzDVb8bztln/phLr7mZ40/7Be7Or8/5X7bbtmO2YgF1w5/PH34Zk54dS0Eiwf1jxlFW9l5WMzTmoO/34dRTT+Tdd8uYNbPuf57Lfn8dzz33Qo6ThX3cQs4GYedTtsxk8Wu12VgTQ44ztjmnv7Jtq+If5DpCo1LeXCfHgv2FiuRA9fqK2P53/Xr2xIz/92rb85icfn0EeZ2KiIi0TLpNi4hIKALoE8mUioqISCjyYPSXioqISCjy4HHCKioiIqHIg5aKOupFRCQ2aqmIiIRCHfUiIhKbPDj9paIiIhKKPGipqE9FRERio5aKiEgo8qCloqIiIhKIfLihpIqKiEgo1FIREZHY5MHoL3XUi4hIbJq9pRLyM0tWvzM21xEa1f57P851BBHJNp3+EhGR2OTB6S8VFRGRUKilIiIiscmDloo66kVEJDZqqYiIhEKnv0REJDYqKiIiEhv1qYiIiPyHioqISChqazOfUjCztmb2upm9bWZzzWxktPx+M1tkZrOjqWe03MzsdjNbYGbvmNm+Tf0IOv0lIhKK5j/9tQ443N1XmVkhMN3M/h69d6G7P7bR+kcDu0TT/sBd0b8bpaIiIhKKZu6od3cHVkWzhdHkKT4yBHgg+twMM9vWzIrcfWljH9DpLxGRUHht5lMTzKzAzGYDy4Hn3f216K1rolNct5pZm2hZCbAk6ePl0bJGqaiIiOQRMxtmZrOSpmHJ77t7jbv3BEqBvma2F/A7YHegD9AJuCjd/ev0l4hIKGI4/eXuo4BRm7DeF2Y2FRjo7jdFi9eZ2Wjgt9F8BdA16WOl0bJGqaUiIhKK5h/9taOZbRu93go4EphvZkXRMgOOA+ZEH3kaOD0aBXYA8GWq/hRoQUVlwFH9mDtnGvPLpjPiwnOzvv9166v48W+v5cTzr+L4867gjrFPA/DaO/MZ+us/cPwvr+TSP42muuY/z5ie+e6/+dHwuvV/esmNWc8MuT9uqYSc7e5RN1NZ/jaz35qS6ygNCvnYKVsG3DOfUisCpprZO8BM6vpUJgIPmdm7wLvADsDV0fqTgA+ABcDdwDlN7cC86RAZadW6JOMdJBIJ5s19mYGDTqG8fCkzXp3Eqaedw7x572e03c15SJe7s/brdbTbqi1V1dWccfENjDjrJC68cRR3/+ECupV05o6HJlC00/accOTBrFy1htMvup67rvwVRTtuz2dfrGT7bbfZ5P3F8ZCu5jpucQg5G8APDt6fVatWM3r0bfTs1T/Xcb4h5GO3JWarXl9hMUVk7cNXZPx9udUpI2PLk44W0VLp26cXCxcuZtGij6iqqmL8+AkcO3hAVjOYGe22agtAdU0N1TU1JBJGYWEB3Uo6A3BAzz3556tvAjBp2uv0P7AXRTtuD7BZBSUuIRy3xoScDeDl6a+x4vMvch2jQSEfO2WTFlFUiku6sKS8sn6+vGIpxcVdsp6jpqaWHw2/in6n/5YDe+7J3rt2p6amlrnvLwbg+VfeYNmnKwD4sPJjVq5aw/9eehMnXXA1T7/watbzhnLcGhJyttCFfOyULUPN3KeSDWmP/jKzn7r76DjDhK6gIMGjf7qclavW8Os/3smCjyq54bc/44b7xlNVVc2BPfekIFFXp2tqaihb+CF3/+EC1q1fz2kjrud7u327vlUjIvJf8uCGkpkMKR4JNFhUonHRwwCsoCOJRPsMdgOVFcvoWlpcP19aUkRl5bKMtpmJbTq0o8/eu/OvN+dy5vFHMeaPIwB45a25fFj5MQCdt9+Ojlt3oF3bNrRr24b9vrsL7y1ektWiEtpxSxZyttCFfOyULUMBtDQylfL0V3R1ZUPTu0Cj347uPs
rde7t770wLCsDMWbPp0aM73bp1pbCwkKFDh/DMxMkZb3dzrPjyK1auWgPA1+vW8+rbZXQv7cJnX6wEYH1VFfc98Q9+NPBQAA7bvydvzVtAdU0Na9et4533FtG9tCirmUM4bo0JOVvoQj52yiZNtVQ6AwOAzzdabsArzZKoATU1NZw//DImPTuWgkSC+8eMo6zsvWztHoBPP/+Sy/40mpraWmrdGXBQbw7t8z1uHv0Y02a9Q22tM/ToQ9n/e7sD8O2uRRzU67uc+KursIRxwpEHs8u3Ut7dIHYhHLfGhJwN4MG/3cGhhxzIDjt0YvEHsxh51U2Mvv+RXMcCwj52ypahZh6Nmw0phxSb2b3AaHef3sB7Y929yXGvcQwpbi6bM6Q42+IYUiwizS/WIcWjR2Q+pPinN+R0SHHKloq7n5XiPX3riYjEKQ/6VHTvLxGRUOTB6K8WcZ2KiIi0DGqpiIgEwmuD7YLeZCoqIiKhUJ+KiIjEJg/6VFRURERCkQenv9RRLyIisVFLRUQkFOpTERGR2KioiIhIbPLg3l/qUxERkdiopSIiEgqd/hIRkdjkwZBiFRURkVDo4kcREYmNWiotW8gPwvrqHyNzHaFRWw+4ItcRRCRQW3RREREJiaujXkREYqPTXyIiEps86KjXxY8iIhIbtVREREKh018iIhKbPOio1+kvEZFQ1HrmUwpm1tbMXjezt81srpmNjJZ3N7PXzGyBmY0zs9bR8jbR/ILo/W5N/QgqKiIiofDazKfU1gGHu/s+QE9goJkdAFwP3OruPYDPgbOi9c8CPo+W3xqtl5KKiojIFsLrrIpmC6PJgcOBx6LlY4DjotdDonmi9/ubmaXah4qKiEgomvn0F4CZFZjZbGA58DywEPjC3aujVcqBkuh1CbAEIHr/S2D7VNtXR72ISCDiuKLezIYBw5IWjXL3UfX7cK8BeprZtsCTwO4Z7zSJioqISChiGFIcFZBRm7DeF2Y2FTgQ2NbMWkWtkVKgIlqtAugKlJtZK6Aj8Fmq7er0l4hIKJp/9NeOUQsFM9sKOBKYB0wFToxWOwOYEL1+Oponev8F99TPPFZLRURky1EEjDGzAuoaFePdfaKZlQGPmNnVwFvAvdH69wJ/M7MFwArg5KZ2oKIiIhKKZr73l7u/A/RqYPkHQN8Gln8N/Ghz9tFiTn8NOKofc+dMY37ZdEZceG6u43zD3aNuprL8bWa/NSUn+19XVc1P/vg3hv7hfk4YeR93PjMdAHfnz0+9zLGX38PxV97L2BfeAOCrtev41R1P1K//1Cvv5iR3yL/TkLNB2PmULQNZGP3V3KyJ02MZa9W6JOMdJBIJ5s19mYGDTqG8fCkzXp3Eqaedw7x578cRMWM/OHh/Vq1azejRt9GzV/9Ytrk5D+lyd9auq6Jd29ZU1dTw0xsfZsTQw/lg2WfM+vdHXHXGIBIJY8XK1XTapj33/H0Gq9auY/gJh7LiqzUcd8W9TLnhHApbFWzS/uJ4SFfIv9OQs0HY+bbEbNXrK1Jet7E5vho+OOPvy63/9ExsedLRIloqffv0YuHCxSxa9BFVVVWMHz+BYwcPyHWsei9Pf40Vn3+Rs/2bGe3atgaguqaW6poazODRl2Yz7IffJ5Go+2+s0zbto/Vh9dfro2K0no7t21KQyO5/CiH/TkPOBmHnUzZp8pvEzHY3s/5m1mGj5QObL9Y3FZd0YUl5Zf18ecVSiou7ZGv3LUJNbS1Dr76fwy+8gwP26Mbe3Ysp//QL/jFrPj++9gHO/fNjfPjx5wCc3G9fFi37jCMvuosT/3A/Fw49vL7wZEvIv9OQs0HY+ZQtQ3lw+itlUTGzX1E3tOyXwBwzG5L09rXNGUw2T0EiwfjLzuQffzybOYuXsqDiE9ZX19CmsBVjLzmdEw7+Hlf+7e8AvDJ3EbuV7sTz1/+CcZeewXWPTGHV2nU5/glEhNrazKcca6ql8jNgP3c/DugH/N7Mzo/ea/RPWzMbZmazzG
xWbe3qjENWViyja2lx/XxpSRGVlcsy3m4+2qZdW/rstjP/mruIzttuTf9euwBweM9deL/8EwAmvDqH/r12xczYeaftKNmhI4uWrchqzpB/pyFng7DzKVuG8r2lAiQ23HzM3RdTV1iONrNbSFFU3H2Uu/d2996JRPuMQ86cNZsePbrTrVtXCgsLGTp0CM9MnJzxdvPFiq/WsHLN1wB8vb6KGfMW073L9hzWswcz/70EgFnvLWHnzp0AKOq0Na/N/xCAz1auZvGyFZTu2DGrmUP+nYacDcLOp2wZyoOi0tR1Kh+bWU93nw3g7qvM7BjgPmDvZk8Xqamp4fzhlzHp2bEUJBLcP2YcZWXvZWv3TXrwb3dw6CEHssMOnVj8wSxGXnUTo+9/JGv7//TLVfx+zN+pra2l1uGo/XbjkO99h549Srjkvmd5cMos2rUp5IrT6jolfzbo+1w+ZhInXjUaB4afcAjbdWiXtbwQ9u805GwQdj5lk5RDis2sFKh29/9qI5rZQe7+r6Z2EMeQ4i3R5gwpzrY4hhSL5Is4hxSv/PmAjL8vt/nrP3I6pDhlS8Xdy1O812RBERGRzRDA6atM6TYtIiKhUFEREZG4eB4UlRZxRb2IiLQMaqmIiIQiD1oqKioiIqHI/QXxGVNREREJhPpUREREkqilIiISijxoqaioiIiEQn0qIiISl3zoU1FREREJRR60VNRRLyIisVFLRUQkEDr9JSIi8cmD018qKoEK+Zkl4zsdmusIjRq64qVcRxBJm6uoiIhIbPKgqKijXkREYqOWiohIIHT6S0RE4qOiIiIiccmHlor6VEREthBm1tXMpppZmZnNNbPzo+VXmlmFmc2OpkFJn/mdmS0ws3+b2YCm9qGWiohIILLQUqkGfuPub5rZ1sAbZvZ89N6t7n5T8spmtidwMvBdoBj4p5nt6u41je1ARUVEJBDNXVTcfSmwNHr9lZnNA0pSfGQI8Ii7rwMWmdkCoC/wamMf0OkvEZFQuGU+bSIz6wb0Al6LFp1nZu+Y2X1mtl20rARYkvSxclIXIRUVEZFQeG3mk5kNM7NZSdOwjfdjZh2Ax4Hh7r4SuAv4DtCTupbMzen+DDr9JSKSR9x9FDCqsffNrJC6gvKQuz8RfebjpPfvBiZGsxVA16SPl0bLGqWWiohIILzWMp5SMTMD7gXmufstScuLklY7HpgTvX4aONnM2phZd2AX4PVU+1BLRUQkEFkY/XUQcBrwrpnNjpZdApxiZj0BBxYDPwdw97lmNh4oo27k2LmpRn6BioqISDB8Mzra09u+Twca2smkFJ+5BrhmU/ehoiIiEghdUS8iIpKkxRSVAUf1Y+6cacwvm86IC8/NdZxvULbU9r11GIPm3EX/F6+vX9b3r7/k8H9ey+H/vJYBM2/j8H9eC4C1KmC/28+m/9TrOGLajez6y2NzkjmE45ZKyPmULX3N3VGfDS2iqCQSCW6/7RqOGXwqe+9zGCeddBx77LFLrmMByrYpPhw3jVdOuf4by17/+Z954YhLeOGIS6h89nUqJ80EoGTw/iRaFzLlsIuZOuBSup/en3Zdd8hq3lCOW2NCzqdsmXHPfMq1JouKmfU1sz7R6z3N7ILkm41lQ98+vVi4cDGLFn1EVVUV48dP4NjBTd7XLCuUrWmfzZjP+i9WNfp+yeADWPJkdNcHd1q1a4MVJCho25ra9dVUfbU2S0nrhHLcGhNyPmXLTN63VMzsCuB24C4z+yPwF6A9cLGZXZqFfAAUl3RhSXll/Xx5xVKKi7tka/cpKVtmtj9gd9Z9+iWrFy0DoGLi61SvWcegd+5k4Bu38/5dz1L1xeqsZgr9uIWcT9mkqdFfJ1J32X4bYBlQ6u4rzewm6u4X0+Aws+i2AMMArKAjiUT7+BJLXul6/PdZ8uQr9fPb9foOXlPLpH3OpfW27TnkqctZPm0Oaz5ansOUItkRQksjU02d/qp29xp3XwMsjO4Rg7uvJcUzytx9lLv3dv
fecRSUyopldC0trp8vLSmisnJZxtuNg7KlzwoSFA/qQ8WEGfXLup7wfT6e+jZeXcO6T1fy2cz32K5n96zmCv24hZxP2TKzJfSprDezdtHr/TYsNLOOZPHBlzNnzaZHj+5069aVwsJChg4dwjMTJ2dr9ykpW/p2OmQvvlpQydqlK+qXra34jJ0O/i4ABe3a0Gm/Hnz1fmVjm2gWoR+3kPMpW2byoU+lqdNfh0T30cf9G5flFAJnNFuqjdTU1HD+8MuY9OxYChIJ7h8zjrKy97K1+5SUrWl97jqPHb+/B607bc3Rb/6Zshsf58OHX6T0uAMpTzr1BbDwvsnsd9vZHPHSDWDw4SPTWDlvSSNbbh6hHLfGhJxP2TLT3FfUZ4N5M7eXWrUuCaBBJnEa3+nQXEdo1NAVL+U6gmxhqtdXxFYJFu41IOPvy+/M+UdOK5Nu0yIiEoh8uE2LioqISCBq8+D0l4qKiEgg8qFPRUVFRCQQIYzeylSLuPeXiIi0DGqpiIgEIoSLFzOloiIiEoh8OP2loiIiEoh8GP2lPhUREYmNWioiIoHQkGIREYmNOupFRCQ2+dCnoqIiIhKIfDj9pY56ERGJjVoqIiKBUJ+KbJFCfmbJztvslOsIjfpo5fJcR5DAqU9FRERikw99KioqIiKByIeWijrqRUQkNioqIiKB8BimVMysq5lNNbMyM5trZudHyzuZ2fNm9n707+2i5WZmt5vZAjN7x8z2bepnUFEREQlErVvGUxOqgd+4+57AAcC5ZrYncDEwxd13AaZE8wBHA7tE0zDgrqZ2oKIiIhIId8t4Sr19X+rub0avvwLmASXAEGBMtNoY4Ljo9RDgAa8zA9jWzIpS7UNFRURkC2Rm3YBewGtAZ3dfGr21DOgcvS4BliR9rDxa1igVFRGRQNTGMJnZMDOblTQN23g/ZtYBeBwY7u4rk99z903pnmmUhhSLiATCyXxIsbuPAkY19r6ZFVJXUB5y9yeixR+bWZG7L41Ob224UrcC6Jr08dJoWaPUUhERCUStZz6lYmYG3AvMc/dbkt56Gjgjen0GMCFp+enRKLADgC+TTpM1SC0VEZFA1MbQUmnCQcBpwLtmNjtadglwHTDezM4CPgSGRu9NAgYBC4A1wE+b2oGKiojIFsLdp0Ojlat/A+s7cO7m7ENFRUQkEHH0qeSaioqISCBqcx0gBioqIiKByIeWikZ/iYhIbFpMURlwVD/mzpnG/LLpjLhws/qNmp2ypSekbK3btObJyX/j2RfH8dz0xxh+0dkAjHvmXiZOfYSJUx/h1TmT+b8HbmliS9kR0rHbmLKlL46LH3PNvJmfX9mqdUnGO0gkEsyb+zIDB51CeflSZrw6iVNPO4d5896PI6Ky5VG2TJ782K79VqxZvZZWrVox/tn7uOqSG5n9xrv17985+iae//uLPDl+Ylrbj+vJj1vi7zXkbNXrK2I7ZzWp88kZf18O+viRnJ5D2+yWipk90BxBUunbpxcLFy5m0aKPqKqqYvz4CRw7eEC2YzRI2dITYrY1q9cC0KqwFa0KW5H8B1eHDu058Ad9eH7S1FzFqxfisdtA2TLjWMZTrqUsKmb29EbTM8AJG+azlJHiki4sKa+sny+vWEpxcZds7T4lZUtPiNkSiQQTpz7CzHlT+NeLM3j7zTn17x056DBemfY6q1atzmHCOiEeuw2ULTO1lvmUa02N/ioFyoB7qLvBmAG9gZtTfSi6gdkwACvoSCLRPvOkIs2straWYw47ma236cD/PXALu+7+Hd6bvxCAwScMZPyDT+Y4oUj4mjr91Rt4A7iUunu+vAisdfeX3P2lxj7k7qPcvbe7946joFRWLKNraXH9fGlJEZWVyzLebhyULT0hZ/tq5SpmTJ/FIf2/D8B2nbZln32/ywvPv5zjZHVCPnbKlplaLOMp11IWFXevdfdbqbvfy6Vm9hdycG3LzFmz6dGjO926daWwsJChQ4fwzMTJ2Y7RIGVLT2jZOm
2/HVtv0wGANm3bcPCh+/PB+4sBOPrYI3hh8susX7c+Z/mShXbskilbZpr7ccLZsEkFwt3LgR+Z2Q+BlU2tH7eamhrOH34Zk54dS0Eiwf1jxlFW9l62YzRI2dITWradOu/AjX+5ioKCBJZIMGnC87wwua5lcszxA/i/20bnLNvGQjt2yZQtMyEMCc5UixhSLLKpMhlS3NziGlIsYYlzSPFjRT/J+PvyxKUPtawhxSIiIo3Rvb9ERAKRD6d1VFRERAKRD30qKioiIoEI4eLFTKlPRUREYqOWiohIIEK4eDFTKioiIoFQR72IiMQmH/pUVFRERAKRD6O/1FEvIiKxUUtFRCQQ6lMREZHYqE9FRERikw99KioqIiKByIeioo56ERGJjVoqkldCfmZJ7x12yXWERs369P1cRxDA1aciIiJx0ekvERGJTW0MU1PM7D4zW25mc5KWXWlmFWY2O5oGJb33OzNbYGb/NrMBTW1fRUVEZMtyPzCwgeW3unvPaJoEYGZ7AicD340+c6eZFaTauIqKiEggPIapyX24TwNWbGKkIcAj7r7O3RcBC4C+qT6goiIiEohay3zKwHlm9k50emy7aFkJsCRpnfJoWaNUVEREAhFHn4qZDTOzWUnTsE3Y9V3Ad4CewFLg5nR/Bo3+EhEJRByjv9x9FDBqMz/z8YbXZnY3MDGarQC6Jq1aGi1rlFoqIiJbODMrSpo9HtgwMuxp4GQza2Nm3YFdgNdTbUstFRGRQGTjLsVm9jDQD9jBzMqBK4B+ZtYzirAY+DmAu881s/FAGVANnOvuNam2r6IiIhKIbNyl2N1PaWDxvSnWvwa4ZlO3r6IiIhKIfLiiXkVFRCQQ+fCQLnXUi4hIbNRSEREJRG0etFVUVEREApEPfSot5vTXgKP6MXfONOaXTWfEhefmOs43KFt6lG3TddimA9eOGskj0x7gkZfGsNd+e3Le78/mkWkP8OA/7+W6e/9Ah2065DomEN6xSxZyNsjOvb+am7k3b4xWrUsy3kEikWDe3JcZOOgUysuXMuPVSZx62jnMm5f7Bwspm7Jtqkwe0vX7P13M26+/y9Njn6VVYSvabtWWPXvtzhvT36KmpoZzL627E8cd12zWhdT14npI15b4e61eXxHbQEjnaasAABCoSURBVOCrvvWTjL8vL//woZw+6qtFtFT69unFwoWLWbToI6qqqhg/fgLHDm7ytv5ZoWzpUbZN137r9vQ6YB+eHvssANVV1axauYrXX5pFTU3ddWhz3ihjp6Idc5Zxg9COXbKQs22QjeepNLfNKipmdrCZXWBmRzVXoIYUl3RhSXll/Xx5xVKKi7tkM0KjlC09yrYZeXYu4vPPvuD3t17MmMl3c8lNF9J2q7bfWGfwKYN49YWUd8/IitCOXbKQs22Q47sUxyJlUTGz15Ne/wz4C7A1cIWZXdzM2UQEKCgoYLe9d+WJByZwxlE/Y+2atZx+3o/r3z/zV6dSXV3Dc088n8OUEodaPOMp15pqqRQmvR4GHOnuI4GjgJ809qHkWy/X1q7OOGRlxTK6lhbXz5eWFFFZuSzj7cZB2dKjbJtu+dJP+GTpJ8x9ax4AL0x8id32ruuf+eHQgRx0xIFccd7VOcuXLLRjlyzkbBvkQ0d9U0UlYWbbmdn21HXqfwLg7qupu7lYg9x9lLv3dvfeiUT7jEPOnDWbHj26061bVwoLCxk6dAjPTJyc8XbjoGzpUbZNt+KTFXxcuZydv1N3B/I+P9iPRe9/yAH9+nLqOSdz4ZmXsG7tupzlSxbasUsWcrZ80tR1Kh2BNwAD3MyK3H2pmXWIlmVFTU0N5w+/jEnPjqUgkeD+MeMoK3svW7tPSdnSo2yb5+bLbmfkXy6jsLAVFR8t5epfX8d9k/5K6zaF3D6u7nlKc94o44aLb8lpzhCP3QYhZ9sghI72TKU1pNjM2gGdo2cWpxTHkGKRfJDJkOLmFteQ4i1RnEOKL+p2Ssbfl9cvfjin3fVpXV
Hv7muAJguKiIhsunz4C1y3aRERCUQ+nP5qERc/iohIy6CWiohIIEK4ziRTKioiIoFo+SVFRUVEJBjqUxEREUmiloqISCA8D06AqaiIiAQiH05/qaiIiARCo79ERCQ2Lb+kqKNeRERipJaKiEggdPpLRERio456ERGJjYYUi4hIbNRSEZFNFvKDsL7dsSjXERr1wZdLcx0hr5jZfcAxwHJ33yta1gkYB3QDFgND3f1zMzPgNmAQsAY4093fTLV9jf4SEQmEx/DPJrgfGLjRsouBKe6+CzAlmgc4GtglmoYBdzW1cRUVEZFA1MYwNcXdpwErNlo8BBgTvR4DHJe0/AGvMwPY1sxSNmt1+ktEJBC1nrOO+s7uvuE84zKgc/S6BFiStF55tKzRc5JqqYiI5BEzG2Zms5KmYZvzeXd3Mri4Xy0VEZFAxNFOcfdRwKjN/NjHZlbk7kuj01vLo+UVQNek9UqjZY1SS0VEJBC1eMZTmp4GzohenwFMSFp+utU5APgy6TRZg9RSEREJRDYufjSzh4F+wA5mVg5cAVwHjDezs4APgaHR6pOoG068gLohxT9tavsqKiIigcjGxY/ufkojb/VvYF0Hzt2c7ev0l4iIxEYtFRGRQOguxSIiEhvdUFJERGKjG0qKiEhsPHdX1MdGHfUiIhIbtVRERAKRDx31LaalMuCofsydM435ZdMZceFmDZtudsqWHmVLX2j5EokET73wEH996FYAbrrrDzz36uNMnDaOa2+7nFatCnKcsE5ox21j2bhLcXNrEUUlkUhw+23XcMzgU9l7n8M46aTj2GOPXXIdC1C2dClb+kLMd8awU1j43qL6+Wcef46BB/4PxxxyEm3btuFHpx6X4tPZEeJx21iWnqfSrFpEUenbpxcLFy5m0aKPqKqqYvz4CRw7eECuYwHKli5lS19o+ToX7US/Iw/i0Qefql/20j//Vf/6nTfn0qW4c0MfzarQjlu+SllUzGx/M9smer2VmY00s2fM7Hoz65idiFBc0oUl5ZX18+UVSyku7pKt3aekbOlRtvSFlu/Sa37DDSNvp7b2v/9KbtWqgCFDB/HyC6/kINk3hXbcGpLDG0rGpqmWyn3U3UQM6p5T3BG4Plo2uhlziUgL0O/Ig/nskxXMfWd+g+9fecPFzHz1TWbNmJ3lZC2Tu2c85VpTo78S7l4dve7t7vtGr6ebWaP/lUQPhRkGYAUdSSTaZxSysmIZXUuL6+dLS4qorFyW0TbjomzpUbb0hZRvv/33of/AQzj0iINo07Y1HTp04MY7r+LCcy7nvN/+jE7bb8fvf3NtTrJtLKTj1pgQOtoz1VRLZY6ZbbjV8dtm1hvAzHYFqhr7kLuPcvfe7t4704ICMHPWbHr06E63bl0pLCxk6NAhPDNxcsbbjYOypUfZ0hdSvpuvvoND9vkhh+93LL/+2aXMmD6TC8+5nB+dOoSDDzuAX//80iD+eoawjltj8qGjvqmWyv8DbjOzy4BPgVfNbAl1zyz+f80dboOamhrOH34Zk54dS0Eiwf1jxlFW9l62dp+SsqVH2dIXej6AkTf+jsolyxj/9/sAmDxxKnfcfE9OM7WE45YPbFP+iog667tTV4TK3f3jTd1Bq9YluS+dIpLStzsW5TpCoz74MuWDBnOuen2FxbWtI7oOyPj78p9L/hFbnnRs0hX17r4SeLuZs4iIbNFCOVWYCd2mRUQkECEMCc5Ui7j4UUREWga1VEREAhHC6K1MqaiIiASiVn0qIiISl5ZfUlRURESCoY56ERGRJGqpiIgEIh9aKioqIiKB0MWPIiISG7VUREQkNvlwnYo66kVEJDZqqYiIBEJ9KiIiEpts9KmY2WLgK6AGqHb33mbWCRgHdAMWA0Pd/fN0tq/TXyIigcjiM+oPc/ee7t47mr8YmOLuuwBTovm0qKUiIkE/CKvHtsVNrySZGgL0i16PAV4ELkpnQ2qpiIgEohbPeNoEDkw2szfMbFi0rLO7b/jLYh
nQOd2fQS0VEZFAxDGkOCoUw5IWjXL3UUnzB7t7hZntBDxvZvO/kcHdzSztICoqIiKBiOPW91EBGZXi/Yro38vN7EmgL/CxmRW5+1IzKwKWp7t/nf4SEQmEx/BPKmbW3sy23vAaOAqYAzwNnBGtdgYwId2fQS0VEZEtR2fgSTODuu//se7+nJnNBMab2VnAh8DQdHegoiIiEojmfvKju38A7NPA8s+A/nHsQ0VFRCQQ+XDvLxUVEZFA6Bn1IiISm3xoqWj0l4iIxEYtFRGRQOj0l4iIxCYfTn+pqIiIBMK9NtcRMqY+FRERiY1aKiIigcjGQ7qaW4tpqQw4qh9z50xjftl0Rlx4bq7jfIOypUfZ0hdyvtCyJRIJnpjyIP/34C0AXH3rZTw19SEmvDiW2+69jnbtt8pxwv/I4kO6mo01d4hWrUsy3kEikWDe3JcZOOgUysuXMuPVSZx62jnMm/d+HBGVTdlaTDYIO19zZMv0IV1nnv1j9tpnDzps3Z6zT72A9h3as3rVagAuvmo4n33yOXf/eUza25+/fKZlFDBJaae9Mv6+LF8xJ7Y86UjZUjGzX5lZ12yFaUzfPr1YuHAxixZ9RFVVFePHT+DYwQNyHQtQtnQpW/pCzhdats5FO3HoEQfz6EP/uenuhoIC0KZtm6BGXOVDS6Wp019/AF4zs5fN7Bwz2zEboTZWXNKFJeWV9fPlFUspLu6Siyj/RdnSo2zpCzlfaNkuufoCbrrqdrz2m6Oqrr3tcqbPfY5v9+jGg/eMy1G6/NRUUfkAKKWuuOwHlJnZc2Z2xoZ78jfEzIaZ2Swzm1Vbu7qx1UREmk2/Iw/ms08/Z+478//rvUvOv4pD9h7EwvcXM2jIUTlI17Ba94ynXGuqqLi717r7ZHc/CygG7gQGUldwGvvQKHfv7e69E4n2GYesrFhG19L/nFctLSmisnJZxtuNg7KlR9nSF3K+kLLt23cfDh/wA6bMmsDNo65l/4P7cMOdV9W/X1tby6QnJ3PUMYflJF9DmvshXdnQVFH5RoePu1e5+9PufgrwreaL9U0zZ82mR4/udOvWlcLCQoYOHcIzEydna/cpKVt6lC19IecLKdst19xBv57H0L/3EH4z7BJemz6TEedczs7dS+vXOXzgIXyw4MOc5GtIPvSpNHWdykmNveHua2LO0qiamhrOH34Zk54dS0Eiwf1jxlFW9l62dp+SsqVH2dIXcr6QswGYGdf9+Uo6dGgPZvy77H2uvPC6XMeqlw/XqbSIIcUisuXKdEhxc4tzSPGOHXfL+Pvyky//ndMhxbqiXkQkECGcvsqUioqISCBCGL2VKRUVEZFA5ENLpcXc+0tERMKnloqISCDyYfSXioqISCDy4fSXioqISCDUUS8iIrEJ4TYrmVJHvYiIxEYtFRGRQOj0l4iIxEYd9SIiEhv1qYiISGyycet7MxtoZv82swVmdnHcP4OKiojIFsLMCoA7gKOBPYFTzGzPOPeh018iIoHIQp9KX2CBu38AYGaPAEOAsrh2oJaKiEggPIapCSXAkqT58mhZbJq9pVK9viLWB8aY2TB3HxXnNuOibOlRtvQoW3pCzhbH96WZDQOGJS0alc2ftyW2VIY1vUrOKFt6lC09ypaekLNlzN1HuXvvpCm5oFQAXZPmS6NlsWmJRUVERNIzE9jFzLqbWWvgZODpOHegjnoRkS2Eu1eb2XnAP4AC4D53nxvnPlpiUQnyXGhE2dKjbOlRtvSEnK3ZufskYFJzbd/y4bYAIiISBvWpiIhIbFpMUWnuWwtkwszuM7PlZjYn11mSmVlXM5tqZmVmNtfMzs91pmRm1tbMXjezt6N8I3OdKZmZFZjZW2Y2MddZNmZmi83sXTObbWazcp0nmZlta2aPmdl8M5tnZgfmOhOAme0WHa8N00ozG57rXPmmRZz+im4t8B5wJHUX68wETnH32K4CzYSZHQKsAh5w971ynWcDMysCitz9TTPbGngDOC6g42ZAe3
dfZWaFwHTgfHefkeNoAJjZBUBvYBt3PybXeZKZ2WKgt7t/mussGzOzMcDL7n5PNMKonbt/ketcyaLvlApgf3f/MNd58klLaanU31rA3dcDG24tEAR3nwasyHWOjbn7Und/M3r9FTCPmK+ezYTXWRXNFkZTEH/lmFkp8EPgnlxnaUnMrCNwCHAvgLuvD62gRPoDC1VQ4tdSikqz31og35lZN6AX8Fpuk3xTdIppNrAceN7dQ8n3J2AEUJvrII1wYLKZvRFdQR2K7sAnwOjo1OE9ZtY+16EacDLwcK5D5KOWUlQkA2bWAXgcGO7uK3OdJ5m717h7T+qu7O1rZjk/fWhmxwDL3f2NXGdJ4WB335e6u82eG52CDUErYF/gLnfvBawGQusDbQ0cCzya6yz5qKUUlWa/tUC+ivoqHgcecvcncp2nMdEpkqnAwFxnAQ4Cjo36LR4BDjezB3Mb6ZvcvSL693LgSepOEYegHChPanE+Rl2RCcnRwJvu/nGug+SjllJUmv3WAvko6gi/F5jn7rfkOs/GzGxHM9s2er0VdQMx5uc2Fbj779y91N27Ufff2gvufmqOY9Uzs/bRwAuiU0tHAUGMPHT3ZcASM9stWtSfGG+rHpNT0KmvZtMirqjPxq0FMmFmDwP9gB3MrBy4wt3vzW0qoO4v7tOAd6N+C4BLoitqQ1AEjIlG4iSA8e4e3PDdAHUGnqz7m4FWwFh3fy63kb7hl8BD0R+AHwA/zXGeelERPhL4ea6z5KsWMaRYRERahpZy+ktERFoAFRUREYmNioqIiMRGRUVERGKjoiIiIrFRURERkdioqIiISGxUVEREJDb/H07FuWzOEsR4AAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 504x432 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Classification Report\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       1.00      1.00      1.00       420\n",
            "           1       1.00      1.00      1.00       397\n",
            "           2       0.99      1.00      1.00       369\n",
            "           3       1.00      1.00      1.00       178\n",
            "           4       0.97      1.00      0.99        37\n",
            "           5       1.00      1.00      1.00        62\n",
            "           6       1.00      1.00      1.00        42\n",
            "           7       1.00      1.00      1.00        43\n",
            "\n",
            "    accuracy                           1.00      1548\n",
            "   macro avg       1.00      1.00      1.00      1548\n",
            "weighted avg       1.00      1.00      1.00      1548\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FNP6aqzc9hE5"
      },
      "source": [
        "# Convert the model to TensorFlow Lite"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ODjnYyld9hE6"
      },
      "source": [
        "# Persist an inference-only copy of the trained model (optimizer state is not stored)\n",
        "tf.keras.models.save_model(model, model_save_path, include_optimizer=False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zRfuK8Y59hE6",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "106250fb-84e1-4ee8-bc8e-dcedb2bfd31c"
      },
      "source": [
        "# Transform model (post-training quantization via Optimize.DEFAULT)\n",
        "\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "tflite_quantized_model = converter.convert()\n",
        "\n",
        "# A context manager closes (and flushes) the file deterministically instead of\n",
        "# relying on garbage collection of an anonymous file object.\n",
        "with open(tflite_save_path, 'wb') as tflite_file:\n",
        "    tflite_file.write(tflite_quantized_model)\n",
        "\n",
        "# Size in bytes of the serialized TFLite model\n",
        "len(tflite_quantized_model)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Assets written to: /tmp/tmpb_379rn7/assets\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Assets written to: /tmp/tmpb_379rn7/assets\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "7840"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 89
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CHBPBXdx9hE6"
      },
      "source": [
        "## Inference test"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mGAzLocO9hE7"
      },
      "source": [
        "# Load the exported .tflite file; tensors must be allocated before set_tensor/invoke\n",
        "interpreter = tf.lite.Interpreter(tflite_save_path)\n",
        "interpreter.allocate_tensors()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oQuDK8YS9hE7"
      },
      "source": [
        "# Per-tensor metadata (e.g. 'index', 'dtype') for the model's inputs and outputs\n",
        "input_details, output_details = interpreter.get_input_details(), interpreter.get_output_details()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2_ixAf_l9hE7"
      },
      "source": [
        "# Feed the first test sample as a batch of one. Cast it to the dtype the TFLite\n",
        "# input tensor declares: set_tensor raises ValueError on a dtype mismatch\n",
        "# (e.g. float64 data vs a float32 input).\n",
        "interpreter.set_tensor(\n",
        "    input_details[0]['index'],\n",
        "    np.array([X_test[0]], dtype=input_details[0]['dtype']))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "s4FoAnuc9hE7",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "330d050b-a3fb-41ce-fa07-dc7a48d715e7"
      },
      "source": [
        "%%time\n",
        "# Run one forward pass, then copy the result out of the output tensor\n",
        "output_index = output_details[0]['index']\n",
        "interpreter.invoke()\n",
        "tflite_results = interpreter.get_tensor(output_index)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "CPU times: user 117 µs, sys: 7 µs, total: 124 µs\n",
            "Wall time: 132 µs\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vONjp19J9hE8",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9338c9f1-499b-4e10-844d-950075e98a75"
      },
      "source": [
        "# Drop size-1 (batch) dimensions once, then show the class scores and the winner\n",
        "tflite_probs = np.squeeze(tflite_results)\n",
        "print(tflite_probs)\n",
        "print(np.argmax(tflite_probs))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[1.0000000e+00 1.7798680e-13 5.1755124e-26 1.1783223e-18 1.0188586e-18\n",
            " 1.1212574e-26 2.5901723e-24 2.6419864e-15]\n",
            "0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2GCHB0ELGs60"
      },
      "source": [
        "## Download model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "a2jM8I3jGdF6",
        "outputId": "bd2c08cc-d57c-4589-e4fa-f80389176b96"
      },
      "source": [
        "# Bundle the keypoint_classifier directory (exported models) into model.zip\n",
        "!zip -r model.zip keypoint_classifier"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "updating: keypoint_classifier/ (stored 0%)\n",
            "updating: keypoint_classifier/keypoint_classifier.tflite (deflated 22%)\n",
            "updating: keypoint_classifier/keypoint_classifier.hdf5 (deflated 60%)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vEqyprC6cMVF"
      },
      "source": [
        "\n",
        "\n",
        "---\n",
        "\n",
        "\n",
        "# ❗️Hyperparameter Tuning"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KrjjSqlLcQ4O"
      },
      "source": [
        "# HParams plugin API plus the %tensorboard notebook magic\n",
        "from tensorboard.plugins.hparams import api as hp\n",
        "\n",
        "%load_ext tensorboard"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YO6VffaRcY3R"
      },
      "source": [
        "# Search space for the hyperparameter sweep and the metric it is scored on\n",
        "METRIC_ACCURACY = 'accuracy'\n",
        "\n",
        "# Widths of the three hidden Dense layers\n",
        "HP_NUM_UNITS_1 = hp.HParam('num_units_1', hp.Discrete([16, 32, 64]))\n",
        "HP_NUM_UNITS_2 = hp.HParam('num_units_2', hp.Discrete([8, 16, 32]))\n",
        "HP_NUM_UNITS_3 = hp.HParam('num_units_3', hp.Discrete([8, 16, 32]))\n",
        "\n",
        "# Dropout rate (continuous interval) and optimizer name\n",
        "HP_DROPOUT = hp.HParam('dropout', hp.RealInterval(0.0, 0.2))\n",
        "HP_OPTIMIZER = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd']))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eMysvzFxctmG"
      },
      "source": [
        "# Register the search space and the tracked metric with the HParams dashboard\n",
        "tuned_hparams = [HP_NUM_UNITS_1, HP_NUM_UNITS_2, HP_NUM_UNITS_3, HP_DROPOUT, HP_OPTIMIZER]\n",
        "\n",
        "with tf.summary.create_file_writer('logs/hparam_tuning').as_default():\n",
        "    hp.hparams_config(\n",
        "        hparams=tuned_hparams,\n",
        "        metrics=[hp.Metric(METRIC_ACCURACY, display_name='Accuracy')],\n",
        "    )"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XSDYqpKZgOBw"
      },
      "source": [
        "# Early stopping shared by every tuning trial: stop once the monitored metric\n",
        "# (val_loss by default) has not improved for 15 epochs. Keras resets the\n",
        "# callback's counters at the start of each fit(), so reuse across trials is safe.\n",
        "# No module-level ModelCheckpoint here: train_test_model builds its own per\n",
        "# trial, and a global one pointing at model_save_path was never used.\n",
        "es_callback = tf.keras.callbacks.EarlyStopping(patience=15, verbose=1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8ssEmXpJcxSp"
      },
      "source": [
        "def train_test_model(hparams, checkpoint_path='logs/hparam_tuning/trial_checkpoint.hdf5'):\n",
        "  \"\"\"Train one candidate model for the given hparams and return its test accuracy.\n",
        "\n",
        "  Args:\n",
        "    hparams: dict mapping the HP_* HParam objects to this trial's values.\n",
        "    checkpoint_path: file the trial's best epoch (lowest val_loss) is saved to.\n",
        "      Deliberately NOT model_save_path: checkpointing there would overwrite the\n",
        "      exported keypoint classifier with whichever trial happened to run last.\n",
        "\n",
        "  Returns:\n",
        "    Accuracy of the reloaded best checkpoint on (X_test, y_test).\n",
        "  \"\"\"\n",
        "  model = tf.keras.models.Sequential([\n",
        "    tf.keras.layers.Input((21 * 2, )),\n",
        "    tf.keras.layers.Dropout(hparams[HP_DROPOUT]),\n",
        "    tf.keras.layers.Dense(hparams[HP_NUM_UNITS_1], activation='relu'),\n",
        "    tf.keras.layers.Dropout(hparams[HP_DROPOUT]),\n",
        "    tf.keras.layers.Dense(hparams[HP_NUM_UNITS_2], activation='relu'),\n",
        "    tf.keras.layers.Dropout(hparams[HP_DROPOUT]),\n",
        "    tf.keras.layers.Dense(hparams[HP_NUM_UNITS_3], activation='relu'),\n",
        "    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')\n",
        "  ])\n",
        "\n",
        "  model.compile(\n",
        "    optimizer=hparams[HP_OPTIMIZER],\n",
        "    loss='sparse_categorical_crossentropy',\n",
        "    metrics=['accuracy']\n",
        "  )\n",
        "\n",
        "  # Drop any checkpoint left by a previous trial: if this trial never improves on\n",
        "  # the initial best (e.g. val_loss is NaN) nothing is saved, and load_model below\n",
        "  # should then fail loudly rather than silently score another trial's weights.\n",
        "  if tf.io.gfile.exists(checkpoint_path):\n",
        "    tf.io.gfile.remove(checkpoint_path)\n",
        "\n",
        "  cp_callback = tf.keras.callbacks.ModelCheckpoint(\n",
        "    checkpoint_path, verbose=1, save_weights_only=False, save_best_only=True)\n",
        "\n",
        "  # NOTE(review): X_test doubles as validation data for checkpoint selection and\n",
        "  # early stopping, so the returned accuracy is optimistic; a separate validation\n",
        "  # split would give an unbiased estimate.\n",
        "  model.fit(\n",
        "    X_train,\n",
        "    y_train,\n",
        "    epochs=50,\n",
        "    batch_size=64,\n",
        "    validation_data=(X_test, y_test),\n",
        "    callbacks=[cp_callback, es_callback],\n",
        "  )\n",
        "\n",
        "  # Reload the epoch with the lowest val_loss (ModelCheckpoint's default monitor)\n",
        "  model = tf.keras.models.load_model(checkpoint_path)\n",
        "\n",
        "  _, accuracy = model.evaluate(X_test, y_test)\n",
        "  return accuracy\n",
        "\n",
        "def run(run_dir, hparams):\n",
        "  \"\"\"Run one tuning trial and log its hparams and test accuracy under run_dir.\"\"\"\n",
        "  with tf.summary.create_file_writer(run_dir).as_default():\n",
        "    hp.hparams(hparams)  # record the values used in this trial\n",
        "    accuracy = train_test_model(hparams)\n",
        "    tf.summary.scalar(METRIC_ACCURACY, accuracy, step=1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5pO9W84DdHVL"
      },
      "source": [
        "# Dropout is a continuous RealInterval, so sample it on an evenly spaced grid that\n",
        "# includes BOTH ends of the interval. np.arange(min, max, 0.1) is half-open and\n",
        "# stopped before max_value, so dropout=0.2 was never evaluated.\n",
        "DROPOUT_STEP = 0.1\n",
        "num_dropout_values = int(round(\n",
        "    (HP_DROPOUT.domain.max_value - HP_DROPOUT.domain.min_value) / DROPOUT_STEP)) + 1\n",
        "dropout_grid = np.linspace(\n",
        "    HP_DROPOUT.domain.min_value, HP_DROPOUT.domain.max_value, num_dropout_values)\n",
        "\n",
        "session_num = 0\n",
        "\n",
        "for num_units_1 in HP_NUM_UNITS_1.domain.values:\n",
        "  for num_units_2 in HP_NUM_UNITS_2.domain.values:\n",
        "    for num_units_3 in HP_NUM_UNITS_3.domain.values:\n",
        "      for dropout_rate in dropout_grid:\n",
        "        for optimizer in HP_OPTIMIZER.domain.values:\n",
        "          hparams = {\n",
        "              HP_NUM_UNITS_1: num_units_1,\n",
        "              HP_NUM_UNITS_2: num_units_2,\n",
        "              HP_NUM_UNITS_3: num_units_3,\n",
        "              HP_DROPOUT: float(dropout_rate),  # plain Python float for logging/printing\n",
        "              HP_OPTIMIZER: optimizer,\n",
        "          }\n",
        "          run_name = \"run-%d\" % session_num\n",
        "          print('--- Starting trial: %s' % run_name)\n",
        "          print({h.name: hparams[h] for h in hparams})\n",
        "          run('logs/hparam_tuning/' + run_name, hparams)\n",
        "          session_num += 1"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "21knbMYldaUn"
      },
      "source": [
        "# Show the HParams dashboard inline. The %tensorboard magic (enabled by\n",
        "# `%load_ext tensorboard`) is supported in Colab and Jupyter; other notebook\n",
        "# front ends may not render it.\n",
        "%tensorboard --logdir logs/hparam_tuning"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "J6_UH6jttBsD"
      },
      "source": [
        "# Clean up: delete the logs/ directory written by the tuning runs\n",
        "!rm -rf logs"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}